{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "8070eae9-13fa-4758-a645-0539faaf6dc2",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow import keras\n",
    "import tensorflow as tf \n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "5fc30778-ad09-4419-9c53-4be1153e7f89",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(5000, 28, 28) (5000,)\n",
      "(55000, 28, 28) (55000,)\n",
      "(10000, 28, 28) (10000,)\n"
     ]
    }
   ],
   "source": [
    "# Load Fashion-MNIST and hold out the first 5000 training samples for validation.\n",
    "fashion_mnist = keras.datasets.fashion_mnist\n",
    "(x_train_all, y_train_all), (x_test, y_test) = fashion_mnist.load_data()\n",
    "\n",
    "x_valid, y_valid = x_train_all[:5000], y_train_all[:5000]\n",
    "x_train, y_train = x_train_all[5000:], y_train_all[5000:]\n",
    "\n",
    "# Print the (images, labels) shapes of each split: validation, train, test.\n",
    "for images, labels in ((x_valid, y_valid), (x_train, y_train), (x_test, y_test)):\n",
    "    print(images.shape, labels.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "ccb9fce1-c814-49c1-b54a-1a783107ad11",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(55000, 784)"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: every image flattens to one 784-element float32 row.\n",
    "np.reshape(x_train.astype(np.float32), (-1, 784)).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "6135eb5f-d1cc-46ab-b967-ca474369748b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standardization: rescale each pixel position using statistics of the training split.\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "\n",
    "def flatten_images(images):\n",
    "    \"\"\"Return `images` as a float32 matrix with one flattened sample per row.\n",
    "\n",
    "    The row count is taken from the array itself rather than hard-coded, so the\n",
    "    helper keeps working when the train/validation split sizes change.\n",
    "    \"\"\"\n",
    "    return images.astype(np.float32).reshape(len(images), -1)\n",
    "\n",
    "\n",
    "scaler = StandardScaler()\n",
    "# Fit on the training split only; validation and test reuse the fitted scaler.\n",
    "x_train_scaled = scaler.fit_transform(flatten_images(x_train))\n",
    "x_valid_scaled = scaler.transform(flatten_images(x_valid))\n",
    "x_test_scaled = scaler.transform(flatten_images(x_test))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "2e51955e-f3eb-42ed-b8af-f926e5577dba",
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_dataset(data, target, epochs, batch_size, shuffle=True,\n",
    "                 shuffle_buffer_size=10000, prefetch_size=50):\n",
    "    \"\"\"Build a batched `tf.data.Dataset` from in-memory features and labels.\n",
    "\n",
    "    Args:\n",
    "        data: Features, sliced along the first axis into samples.\n",
    "        target: Labels, sliced along the first axis in step with `data`.\n",
    "        epochs: Number of times the samples are repeated.\n",
    "        batch_size: Number of samples per batch.\n",
    "        shuffle: Whether to shuffle the samples before repeating and batching.\n",
    "        shuffle_buffer_size: Size of the shuffle buffer. Pass `len(data)` or more\n",
    "            to shuffle across the whole dataset instead of a sliding window.\n",
    "        prefetch_size: Number of batches to prefetch.\n",
    "\n",
    "    Returns:\n",
    "        A dataset yielding `(data, target)` batches.\n",
    "    \"\"\"\n",
    "    dataset = tf.data.Dataset.from_tensor_slices((data, target))\n",
    "    if shuffle:\n",
    "        dataset = dataset.shuffle(shuffle_buffer_size)\n",
    "    # Repeat comes before batch, so a batch may straddle an epoch boundary.\n",
    "    return dataset.repeat(epochs).batch(batch_size).prefetch(prefetch_size)\n",
    "\n",
    "# Training hyper-parameters; `train_dataset` repeats the training data `epochs` times.\n",
    "batch_size, epochs = 64, 20\n",
    "train_dataset = make_dataset(\n",
    "    data=x_train_scaled,\n",
    "    target=y_train,\n",
    "    epochs=epochs,\n",
    "    batch_size=batch_size,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "4730413a-1782-4759-aaa0-c5fa6d733289",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "d:\\developer\\python396\\lib\\site-packages\\keras\\src\\layers\\core\\dense.py:87: UserWarning: Do not pass an `input_shape`/`input_dim` argument to a layer. When using Sequential models, prefer using an `Input(shape)` object as the first layer in the model instead.\n",
      "  super().__init__(activity_regularizer=activity_regularizer, **kwargs)\n"
     ]
    }
   ],
   "source": [
    "# Fully connected classifier: 784 inputs, hidden layers of 512 and 256 units,\n",
    "# and a 10-way softmax output.\n",
    "# The input shape is declared with an explicit `keras.Input` instead of passing\n",
    "# `input_shape` to the first Dense layer, which Keras warns against.\n",
    "model = keras.models.Sequential([\n",
    "    keras.Input(shape=(784,)),\n",
    "    keras.layers.Dense(512, activation='relu'),\n",
    "    keras.layers.Dense(256, activation='relu'),\n",
    "    keras.layers.Dense(10, activation='softmax'),\n",
    "])\n",
    "\n",
    "# Labels are integer class ids, hence the sparse categorical cross-entropy loss.\n",
    "model.compile(\n",
    "    loss='sparse_categorical_crossentropy',\n",
    "    optimizer='adam',\n",
    "    metrics=['accuracy'],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "a44fd854-32e9-4ef7-8255-e9f30fc85e8f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Model: \"sequential\"</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1mModel: \"sequential\"\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓\n",
       "┃<span style=\"font-weight: bold\"> Layer (type)                         </span>┃<span style=\"font-weight: bold\"> Output Shape                </span>┃<span style=\"font-weight: bold\">         Param # </span>┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩\n",
       "│ dense (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>)                        │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">512</span>)                 │         <span style=\"color: #00af00; text-decoration-color: #00af00\">401,920</span> │\n",
       "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
       "│ dense_1 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>)                      │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">256</span>)                 │         <span style=\"color: #00af00; text-decoration-color: #00af00\">131,328</span> │\n",
       "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
       "│ dense_2 (<span style=\"color: #0087ff; text-decoration-color: #0087ff\">Dense</span>)                      │ (<span style=\"color: #00d7ff; text-decoration-color: #00d7ff\">None</span>, <span style=\"color: #00af00; text-decoration-color: #00af00\">10</span>)                  │           <span style=\"color: #00af00; text-decoration-color: #00af00\">2,570</span> │\n",
       "└──────────────────────────────────────┴─────────────────────────────┴─────────────────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓\n",
       "┃\u001b[1m \u001b[0m\u001b[1mLayer (type)                        \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mOutput Shape               \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m        Param #\u001b[0m\u001b[1m \u001b[0m┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩\n",
       "│ dense (\u001b[38;5;33mDense\u001b[0m)                        │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m512\u001b[0m)                 │         \u001b[38;5;34m401,920\u001b[0m │\n",
       "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
       "│ dense_1 (\u001b[38;5;33mDense\u001b[0m)                      │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m256\u001b[0m)                 │         \u001b[38;5;34m131,328\u001b[0m │\n",
       "├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤\n",
       "│ dense_2 (\u001b[38;5;33mDense\u001b[0m)                      │ (\u001b[38;5;45mNone\u001b[0m, \u001b[38;5;34m10\u001b[0m)                  │           \u001b[38;5;34m2,570\u001b[0m │\n",
       "└──────────────────────────────────────┴─────────────────────────────┴─────────────────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Total params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">535,818</span> (2.04 MB)\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1m Total params: \u001b[0m\u001b[38;5;34m535,818\u001b[0m (2.04 MB)\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">535,818</span> (2.04 MB)\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1m Trainable params: \u001b[0m\u001b[38;5;34m535,818\u001b[0m (2.04 MB)\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\"> Non-trainable params: </span><span style=\"color: #00af00; text-decoration-color: #00af00\">0</span> (0.00 B)\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1m Non-trainable params: \u001b[0m\u001b[38;5;34m0\u001b[0m (0.00 B)\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Show the layer-by-layer architecture and parameter counts.\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "1dccde73-0fde-4ad7-a0ee-cc3931a4f235",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Validation data: a single unshuffled pass, in batches of 32.\n",
    "eval_dataset = make_dataset(\n",
    "    data=x_valid_scaled,\n",
    "    target=y_valid,\n",
    "    epochs=1,\n",
    "    batch_size=32,\n",
    "    shuffle=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "3e46b4ca-0646-4171-b551-055005ca8f9b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/10\n",
      "\u001b[1m859/859\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m10s\u001b[0m 9ms/step - accuracy: 0.8051 - loss: 0.5396 - val_accuracy: 0.8692 - val_loss: 0.3567\n",
      "Epoch 2/10\n",
      "\u001b[1m859/859\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m7s\u001b[0m 8ms/step - accuracy: 0.8808 - loss: 0.3260 - val_accuracy: 0.8754 - val_loss: 0.3549\n",
      "Epoch 3/10\n",
      "\u001b[1m859/859\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m7s\u001b[0m 9ms/step - accuracy: 0.8912 - loss: 0.2901 - val_accuracy: 0.8756 - val_loss: 0.3372\n",
      "Epoch 4/10\n",
      "\u001b[1m859/859\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m7s\u001b[0m 8ms/step - accuracy: 0.9013 - loss: 0.2605 - val_accuracy: 0.8848 - val_loss: 0.3278\n",
      "Epoch 5/10\n",
      "\u001b[1m859/859\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m7s\u001b[0m 8ms/step - accuracy: 0.9080 - loss: 0.2400 - val_accuracy: 0.8940 - val_loss: 0.3142\n",
      "Epoch 6/10\n",
      "\u001b[1m859/859\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m7s\u001b[0m 9ms/step - accuracy: 0.9177 - loss: 0.2207 - val_accuracy: 0.8842 - val_loss: 0.3332\n",
      "Epoch 7/10\n",
      "\u001b[1m859/859\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m7s\u001b[0m 9ms/step - accuracy: 0.9212 - loss: 0.2037 - val_accuracy: 0.8896 - val_loss: 0.3388\n",
      "Epoch 8/10\n",
      "\u001b[1m859/859\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m8s\u001b[0m 9ms/step - accuracy: 0.9312 - loss: 0.1830 - val_accuracy: 0.8934 - val_loss: 0.3373\n",
      "Epoch 9/10\n",
      "\u001b[1m859/859\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m8s\u001b[0m 10ms/step - accuracy: 0.9362 - loss: 0.1689 - val_accuracy: 0.8872 - val_loss: 0.3562\n",
      "Epoch 10/10\n",
      "\u001b[1m859/859\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m7s\u001b[0m 9ms/step - accuracy: 0.9386 - loss: 0.1600 - val_accuracy: 0.8970 - val_loss: 0.3699\n"
     ]
    }
   ],
   "source": [
    "# NOTE(review): `train_dataset` was built to repeat `epochs` times, while only 10\n",
    "# epochs are trained here; keep this value at or below `epochs` so the dataset is\n",
    "# not exhausted before training finishes.\n",
    "steps_per_epoch = len(x_train_scaled) // batch_size\n",
    "history = model.fit(\n",
    "    train_dataset,\n",
    "    steps_per_epoch=steps_per_epoch,\n",
    "    epochs=10,\n",
    "    validation_data=eval_dataset,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "848c01b0-9b3c-43c6-b142-38142d3beece",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[1m157/157\u001b[0m \u001b[32m━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[37m\u001b[0m \u001b[1m0s\u001b[0m 2ms/step - accuracy: 0.8976 - loss: 0.3693\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[0.3698732554912567, 0.8970000147819519]"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Report [loss, accuracy] on the validation split.\n",
    "model.evaluate(x=eval_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "caed9575-4aaa-4b9f-9645-96ee57021e6c",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
